import os

# Guard: refuse to run in the solveit environment (detected via /app/data) —
# the memory requirements of this analysis would crash that VM.
# NOTE: `assert` is stripped under `python -O`; acceptable here since this is a
# notebook-style analysis script, not library code.
assert not os.path.isdir('/app/data'), "Do not try to run this on solveit. The memory requirements will crash the VM."

# Inspect
import torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from midi_rae.vit import ViTEncoder, ViTDecoder
from midi_rae.swin import SwinEncoder, SwinDecoder
from midi_rae.data import PRPairDataset
from midi_rae.viz import make_emb_viz, viz_mae_recon
from midi_rae.utils import load_checkpoint
import matplotlib.pyplot as plt
# Interactive visualization (without wandb logging)
import plotly.io as pio
pio.renderers.default = 'notebook'
from midi_rae.viz import umap_project, pca_project, plot_embeddings_3d, make_emb_viz, viz_mae_recon

Config
# Load the experiment configuration (Swin variant).
# cfg = OmegaConf.load('../configs/config.yaml')
cfg = OmegaConf.load('../configs/config_swin.yaml')

# device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
# Force CPU so the GPU stays free for training while we run this analysis.
device = 'cpu'
print(f'device = {device}')
# -> device = cpu
Load Dataset
# Build the validation dataset and its loader.
val_ds = PRPairDataset(
    image_dataset_dir=cfg.data.path,
    split='val',
    max_shift_x=cfg.training.max_shift_x,
    max_shift_y=cfg.training.max_shift_y,
)
val_dl = DataLoader(val_ds, batch_size=cfg.training.batch_size, num_workers=4, drop_last=True)
print(f'Loaded {len(val_ds)} validation samples, batch_size = {cfg.training.batch_size}')
# -> Loading 91 val files from /home/shawley/datasets/POP909_images_basic...
# -> Finished loading.
# -> Loaded 9100 validation samples, batch_size = 380
Inspect Data
# Grab a single batch and move its tensors onto the analysis device.
batch = next(iter(val_dl))
img1 = batch['img1'].to(device)
img2 = batch['img2'].to(device)
deltas = batch['deltas'].to(device)
file_idx = batch['file_idx'].to(device)
print("img1.shape, deltas.shape, file_idx.shape =", tuple(img1.shape), tuple(deltas.shape), tuple(file_idx.shape))
# -> img1.shape, deltas.shape, file_idx.shape = (380, 1, 128, 128) (380, 2) (380,)

# Show a sample image pair
idx = 0
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(img1[idx, 0].cpu(), cmap='gray')
axes[0].set_title(f'Image 1 (file_idx={file_idx[idx].item()})')
axes[1].imshow(img2[idx, 0].cpu(), cmap='gray')
axes[1].set_title(f'Image 2 (deltas = {deltas[idx].cpu().int().numpy()})')
plt.tight_layout()
plt.show()
Load Encoder from Checkpoint
# if cfg.model.get('encoder', 'vit') == 'swin':
# model = ViTEncoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size),
# cfg.model.patch_size, cfg.model.dim, cfg.model.depth, cfg.model.heads).to(device)
# ckpt_path = f'../checkpoints/{}__best.pt' # <-- change as needed
# ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
# state_dict = {k.replace('_orig_mod.', ''): v for k, v in ckpt['model_state_dict'].items()}
# model.load_state_dict(state_dict, strict=False)
# Instantiate whichever encoder family the config specifies, then restore weights.
encoder_kind = cfg.model.get('encoder', 'vit')
if encoder_kind == 'swin':
    encoder = SwinEncoder(
        img_height=cfg.data.image_size, img_width=cfg.data.image_size,
        patch_h=cfg.model.patch_h, patch_w=cfg.model.patch_w,
        embed_dim=cfg.model.embed_dim, depths=cfg.model.depths,
        num_heads=cfg.model.num_heads, window_size=cfg.model.window_size,
        mlp_ratio=cfg.model.mlp_ratio, drop_path_rate=cfg.model.drop_path_rate,
    ).to(device)
else:
    encoder = ViTEncoder(cfg.data.in_channels, cfg.data.image_size, cfg.model.patch_size,
                         cfg.model.dim, cfg.model.depth, cfg.model.heads).to(device)

# Checkpoint path defaults to one derived from the encoder class name,
# unless the config supplies an explicit `encoder_ckpt`.
encoder = load_checkpoint(encoder, cfg.get('encoder_ckpt', f'../checkpoints/{encoder.__class__.__name__}__best.pt'))
encoder.eval()
print(f"Loaded {encoder.__class__.__name__}")
# -> >>> Loaded model checkpoint from ../checkpoints/SwinEncoder__best.pt
# -> Loaded SwinEncoder
Run Batch Through Encoder
# Encode both images of each pair; no gradients needed for analysis.
with torch.no_grad():
    enc_out1, enc_out2 = (encoder(x) for x in (img1, img2))
# z1 = enc_out1.patches.all_emb.reshape(-1, enc_out1.patches[1].dim)
# z2 = enc_out2.patches.all_emb.reshape(-1, enc_out2.patches[1].dim)
# num_tokens = enc_out1.patches.all_emb.shape[1]
# print(f'z1: {z1.shape}, z2: {z2.shape}, num_tokens: {num_tokens}')

# Visualize Embeddings
NOTE: This will visualize all embeddings in the entire batch, not just the single pair of images shown above.
# NOTE: this visualizes ALL embeddings in the entire batch, not just the single
# pair of images shown above.
figs = make_emb_viz((enc_out1, enc_out2), encoder=encoder, batch=batch, do_umap=False)
figs.keys()  # show what figures are available
# -> dict_keys(['cls_pca_fig', 'cls_umap_fig', 'patch_pca_fig', 'patch_umap_fig', 'empty_pca_fig'])

# Keep the plotly output cells hidden, or the plotly.js will swamp the LLM context.
figs['cls_pca_fig'].show()
# The CLS tokens are nicely grouped in pairs. Let's see if the same is true for
# the randomly-sampled pairs of non-empty patch embeddings:
figs['patch_pca_fig'].show()
def svd_analysis(enc_out, level=1, title='', top_k=20):
    "Run SVD on encoder output, plot singular value spectrum and cumulative variance"
    # Flatten the batch into a (num_vectors, dim) matrix and center it.
    # Assumes enc_out.patches[level] exposes `.emb` and `.dim` — project type.
    patches = enc_out.patches[level]
    z = patches.emb.detach().cpu().float().reshape(-1, patches.dim)
    z = z - z.mean(dim=0)

    # torch returns V transposed; technically "V hermitian", but the data is real.
    U, S, Vt = torch.linalg.svd(z, full_matrices=False)
    squared = S**2
    var_exp = squared / squared.sum()        # fraction of variance per component
    cum_var = var_exp.cumsum(0)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
    ax1.semilogy(S.numpy())
    ax1.axvline(x=top_k, color='r', ls='--', alpha=0.5)
    ax1.set(xlabel='Component', ylabel='Singular value', title=f'{title} Singular Values')
    ax2.bar(range(top_k), var_exp[:top_k].numpy())
    ax2.set(xlabel='Component', ylabel='Variance explained', title=f'{title} Top {top_k} Variance')
    ax3.plot(cum_var.numpy())
    ax3.axhline(y=0.9, color='r', ls='--', alpha=0.5, label='90%')
    ax3.set(xlabel='Component', ylabel='Cumulative variance', title=f'{title} Cumulative Variance')
    ax3.legend()
    plt.tight_layout()
    plt.show()

    # Smallest number of components whose cumulative variance reaches 90%.
    n90 = (cum_var < 0.9).sum().item() + 1
    print(f"{title}: {n90} components for 90% variance, top-1 explains {var_exp[0]:.1%}")
    return S, U, Vt, var_exp

S, U, Vt, var_exp = svd_analysis(enc_out2, title='Patches')
Patches: 7 components for 90% variance, top-1 explains 59.3%
# Repeat the SVD analysis on the CLS-token embeddings (level 0).
cls_S, cls_U, cls_Vt, cls_var_exp = svd_analysis(enc_out2, title='CLS', level=0)
# -> CLS: 2 components for 90% variance, top-1 explains 85.9%
Two key takeaways (NOTE: the dimension counts below appear to come from an earlier run — the outputs above report 7 components for Patches and 2 for CLS; re-run and reconcile before trusting the exact numbers):
- Patches need 178/256 dims for 90%. The representation is highly distributed with no dominant direction. This means the encoder is using nearly all its capacity, which is healthy (no dimensional collapse). But it also suggests rhythm and pitch aren’t cleanly factored — if they were, you’d expect a sharper elbow in the spectrum (the first 1 or 2 components notwithstanding).
- CLS only needs 23/256 dims. The global summary is much more compressed. That’s interesting for generation: it suggests the “gist” of a musical passage lives in a ~23-dimensional subspace. The gradual slope in the top-20 bars (no single dominant component) means it’s not collapsing to a trivial representation either.
Decoder Performance
# Build the decoder matching the encoder family, then restore its weights.
if cfg.model.get('encoder', 'vit') == 'swin':  # decoder should match encoder
    decoder = SwinDecoder(img_height=cfg.data.image_size, img_width=cfg.data.image_size,
                          patch_h=cfg.model.patch_h, patch_w=cfg.model.patch_w,
                          embed_dim=cfg.model.embed_dim,
                          depths=cfg.model.get('dec_depths', cfg.model.depths),
                          num_heads=cfg.model.get('dec_num_heads', cfg.model.num_heads)).to(device)
else:
    decoder = ViTDecoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size),
                         cfg.model.patch_size, cfg.model.dim,
                         cfg.model.get('dec_depth', 4), cfg.model.get('dec_heads', 8)).to(device)
name = decoder.__class__.__name__
print("Name = ",name)
# BUG FIX: this previously looked up cfg 'encoder_ckpt', which would load the
# *encoder's* checkpoint path into the decoder whenever the config defines it.
decoder = load_checkpoint(decoder, cfg.get('decoder_ckpt', f'../checkpoints/{decoder.__class__.__name__}__best.pt'))
# Match the encoder's eval() call above: disable dropout/drop-path for analysis.
decoder.eval()
# -> Name =  SwinDecoder
# -> >>> Loaded model checkpoint from ../checkpoints/SwinDecoder__best.pt
# Reconstruct from the encoder output and compare against the ground truth.
recon_logits = decoder(enc_out2)
img_recon = torch.sigmoid(recon_logits)  # logits -> pixel probabilities
img_real = img2
print("img_recon.shape, img_real.shape =", img_recon.shape, img_real.shape)
# -> img_recon.shape, img_real.shape = torch.Size([380, 1, 128, 128]) torch.Size([380, 1, 128, 128])

grid_recon, grid_real, grid_map, evals = viz_mae_recon(img_recon, img_real, enc_out=None, epoch=0, debug=False, return_maps=True)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))
ax1.imshow(grid_real.permute(1, 2, 0), cmap='gray')
ax1.set_title('Real')
ax2.imshow(grid_recon.permute(1, 2, 0), cmap='gray')
ax2.set_title('Recon')
ax3.imshow(grid_map.permute(1, 2, 0))
ax3.set_title('Map')
plt.show()

# Print the scalar metrics, skipping the per-pixel map entries.
metrics = (f"{k}: {v.item():.4f}" for k, v in evals.items() if not k.endswith('map'))
print(', '.join(metrics))
precision: 0.9990, recall: 0.9998, specificity: 1.0000, f1: 0.9994
Wow! F1 = 0.9994!
That’s nearly perfect reconstruction: an F1 score of 99.94%. Seems we have our representation autoencoder!
Let’s Show the map image really big. It’s designed to show red pixels wherever there are False Positives and yellow pixels wherever there are False Negatives (and white = True Pos, black = True Neg)…
I don’t see any red or yellow, do you?
In the next cell we’re gonna plot an image showing the maps as a very large image, we’re gonna hide it from the LLM because it doesn’t need to see it and we wanna spare the context.
from PIL import Image
from IPython.display import display

# Render the FP/FN map at full resolution (keep this cell hidden from the LLM
# to spare context). Scale [0,1] floats to 8-bit and reorder to HWC for PIL.
grid_u8 = (grid_map*255).permute(1,2,0).byte().numpy()
display(Image.fromarray(grid_u8))